{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/envs/mathmodal/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import LabelBinarizer\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "GRU 的主要机制：\n",
    "\n",
    "**更新门**  **$z_t$**  **：**  控制前一步状态 $h_{t-1}$  对当前状态 $h_t$ 的影响。\n",
    "\n",
    "**重置门**  **$r_t$**  **：**  控制前一步状态  $h_{t-1}$  对候选状态 $\\tilde{h}_t$ 的影响。\n",
    "\n",
    "通过门控机制合成最终的隐状态  $h_t$。\n",
    "\n",
    "公式\n",
    "\n",
    "$z_t = \\sigma(W_z \\cdot [h_{t-1}, x_t])$\n",
    "\n",
    "$r_t = \\sigma(W_r \\cdot [h_{t-1}, x_t])$\n",
    "\n",
    "$\\tilde{h}$​*$t$*​$ $​*$=$*​$  $​*$\\text{tanh}(W_h$*​$ $​*$\\cdot [r_t$*​$ $​*$*$*​$ $​*$h$*​${t-1}, x_t])$\n",
    "\n",
    "$h_t = (1 - z_t) * h$​*${t-1} + z_t$*​$ $​*$*$*​$  $​*$\\tilde{h}$*​$t$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义 GRU 模型\n",
    "class GRUModel(nn.Module):\n",
    "    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim):\n",
    "        super(GRUModel, self).__init__()\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_dim)\n",
    "        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True)\n",
    "        self.fc1 = nn.Linear(hidden_dim, 50)\n",
    "        self.fc2 = nn.Linear(50, output_dim)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.softmax = nn.Softmax(dim=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.embedding(x)\n",
    "        _, h_n = self.gru(x)\n",
    "        h_n = h_n.squeeze(0)  # Remove extra dimension from GRU output\n",
    "        x = self.relu(self.fc1(h_n))\n",
    "        x = self.softmax(self.fc2(x))\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 模拟数据\n",
    "vocab_size = 50000\n",
    "embed_dim = 100\n",
    "hidden_dim = 100\n",
    "output_dim = 7\n",
    "sequence_length = 3000\n",
    "\n",
    "# 生成随机数据\n",
    "X = np.random.randint(0, vocab_size, (1000, sequence_length))  # 1000 条样本\n",
    "y = np.random.randint(0, output_dim, 1000)  # 分类标签"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 数据处理\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)\n",
    "lb = LabelBinarizer()\n",
    "y_train = lb.fit_transform(y_train)\n",
    "y_test = lb.transform(y_test)\n",
    "\n",
    "X_train = torch.LongTensor(X_train)\n",
    "X_test = torch.LongTensor(X_test)\n",
    "y_train = torch.FloatTensor(y_train)\n",
    "y_test = torch.FloatTensor(y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 创建模型\n",
    "model = GRUModel(vocab_size, embed_dim, hidden_dim, output_dim)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GRUModel(\n",
       "  (embedding): Embedding(50000, 100)\n",
       "  (gru): GRU(100, 100, batch_first=True)\n",
       "  (fc1): Linear(in_features=100, out_features=50, bias=True)\n",
       "  (fc2): Linear(in_features=50, out_features=7, bias=True)\n",
       "  (relu): ReLU()\n",
       "  (softmax): Softmax(dim=1)\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 训练\n",
    "epochs = 1\n",
    "batch_size = 32\n",
    "model.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/1 [00:45<?, ?it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m/root/MyCode/trashbin/Gated_Recurrent_Unit_GRU_PyTorch.ipynb Cell 9\u001b[0m line \u001b[0;36m8\n\u001b[1;32m      <a href='vscode-notebook-cell://localhost:8080/root/MyCode/trashbin/Gated_Recurrent_Unit_GRU_PyTorch.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=4'>5</a>\u001b[0m batch_x, batch_y \u001b[39m=\u001b[39m X_train[indices], y_train[indices]\n\u001b[1;32m      <a href='vscode-notebook-cell://localhost:8080/root/MyCode/trashbin/Gated_Recurrent_Unit_GRU_PyTorch.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=6'>7</a>\u001b[0m optimizer\u001b[39m.\u001b[39mzero_grad()\n\u001b[0;32m----> <a href='vscode-notebook-cell://localhost:8080/root/MyCode/trashbin/Gated_Recurrent_Unit_GRU_PyTorch.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=7'>8</a>\u001b[0m outputs \u001b[39m=\u001b[39m model(batch_x)\n\u001b[1;32m      <a href='vscode-notebook-cell://localhost:8080/root/MyCode/trashbin/Gated_Recurrent_Unit_GRU_PyTorch.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=8'>9</a>\u001b[0m loss \u001b[39m=\u001b[39m criterion(outputs, torch\u001b[39m.\u001b[39mmax(batch_y, \u001b[39m1\u001b[39m)[\u001b[39m1\u001b[39m])\n\u001b[1;32m     <a href='vscode-notebook-cell://localhost:8080/root/MyCode/trashbin/Gated_Recurrent_Unit_GRU_PyTorch.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=9'>10</a>\u001b[0m loss\u001b[39m.\u001b[39mbackward()\n",
      "File \u001b[0;32m/opt/anaconda3/envs/mathmodal/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1129\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
      "\u001b[1;32m/root/MyCode/trashbin/Gated_Recurrent_Unit_GRU_PyTorch.ipynb Cell 9\u001b[0m line \u001b[0;36m1\n\u001b[1;32m     <a href='vscode-notebook-cell://localhost:8080/root/MyCode/trashbin/Gated_Recurrent_Unit_GRU_PyTorch.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=11'>12</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, x):\n\u001b[1;32m     <a href='vscode-notebook-cell://localhost:8080/root/MyCode/trashbin/Gated_Recurrent_Unit_GRU_PyTorch.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=12'>13</a>\u001b[0m     x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39membedding(x)\n\u001b[0;32m---> <a href='vscode-notebook-cell://localhost:8080/root/MyCode/trashbin/Gated_Recurrent_Unit_GRU_PyTorch.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=13'>14</a>\u001b[0m     _, h_n \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mgru(x)\n\u001b[1;32m     <a href='vscode-notebook-cell://localhost:8080/root/MyCode/trashbin/Gated_Recurrent_Unit_GRU_PyTorch.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=14'>15</a>\u001b[0m     h_n \u001b[39m=\u001b[39m h_n\u001b[39m.\u001b[39msqueeze(\u001b[39m0\u001b[39m)  \u001b[39m# Remove extra dimension from GRU output\u001b[39;00m\n\u001b[1;32m     <a href='vscode-notebook-cell://localhost:8080/root/MyCode/trashbin/Gated_Recurrent_Unit_GRU_PyTorch.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=15'>16</a>\u001b[0m     x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mrelu(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfc1(h_n))\n",
      "File \u001b[0;32m/opt/anaconda3/envs/mathmodal/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1129\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
      "File \u001b[0;32m/opt/anaconda3/envs/mathmodal/lib/python3.10/site-packages/torch/nn/modules/rnn.py:950\u001b[0m, in \u001b[0;36mGRU.forward\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m    948\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcheck_forward_args(\u001b[39minput\u001b[39m, hx, batch_sizes)\n\u001b[1;32m    949\u001b[0m \u001b[39mif\u001b[39;00m batch_sizes \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 950\u001b[0m     result \u001b[39m=\u001b[39m _VF\u001b[39m.\u001b[39;49mgru(\u001b[39minput\u001b[39;49m, hx, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_flat_weights, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbias, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mnum_layers,\n\u001b[1;32m    951\u001b[0m                      \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdropout, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtraining, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbidirectional, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbatch_first)\n\u001b[1;32m    952\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m    953\u001b[0m     result \u001b[39m=\u001b[39m _VF\u001b[39m.\u001b[39mgru(\u001b[39minput\u001b[39m, batch_sizes, hx, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_flat_weights, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias,\n\u001b[1;32m    954\u001b[0m                      \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnum_layers, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdropout, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtraining, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbidirectional)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "for epoch in tqdm(range(epochs)):\n",
    "    permutation = torch.randperm(X_train.size(0))\n",
    "    for i in range(0, X_train.size(0), batch_size):\n",
    "        indices = permutation[i:i + batch_size]\n",
    "        batch_x, batch_y = X_train[indices], y_train[indices]\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(batch_x)\n",
    "        loss = criterion(outputs, torch.max(batch_y, 1)[1])\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    print(f\"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 测试\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    outputs = model(X_test)\n",
    "    predicted = torch.argmax(outputs, dim=1)\n",
    "    y_true = torch.argmax(y_test, dim=1)\n",
    "    accuracy = (predicted == y_true).float().mean()\n",
    "    print(f\"Test Accuracy: {accuracy:.3f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mathmodal",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
